{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Privacy-Preserving Scientific Computing with NumPy in SPU\n",
    "\n",
    "NumPy is one of the most popular tool for scientific computing. It is so common that we could find lots of equivalents of NumPy in other languages like [xtensor](https://xtensor.readthedocs.io/en/latest/) and [Gonum](https://www.gonum.org/). So we can't help thinking whether we could express computation with NumPy-like APIs in privacy-preserving settings since everyone loves NumPy.\n",
    "\n",
    "\n",
    "Luckily, with the power of [JAX](https://jax.readthedocs.io/en/latest/) NumPy package, we could easily accomplish this goal. In this tutorial, we would go through:\n",
    "- The relation between JAX and SPU\n",
    "- Write a Jittable JAX Program\n",
    "- Execute JAX Program with SPU"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## The relation between JAX and SPU\n",
    "\n",
    "### TL;DR\n",
    "\n",
    "SPU actually consists of two components - Compiler and Runtime. SPU Runtime could only execute [PPHlo](https://www.secretflow.org.cn/docs/spu/en/reference/pphlo_doc.html). One example of PPHlo kernel is [**pphlo.add**](https://www.secretflow.org.cn/docs/spu/en/reference/pphlo_doc.html#pphlo-add-mlir-pphlo-addop). Actually we just feed PPHlo programs to SPU Runtime directly to execute MPC programs in some internal applications when the logic is extremely simple and clear.\n",
    "\n",
    "SPU compiler could translate [XLA](https://www.tensorflow.org/xla) programs to [PPHlo](https://www.secretflow.org.cn/docs/spu/en/reference/pphlo_doc.html). You could check \"Supported\" XLA ops in [this documentation](https://www.secretflow.org.cn/docs/spu/en/reference/xla_status.html). You may find XLA ops are very similar to PPHlo ops in most cases. It seems we still couldn't write XLA programs by hand. You are absolutely correct. If you happen to check [here](https://www.tensorflow.org/xla#xla_frontends), you should  find actually there are lot's of AI frameworks which could emit XLA programs without your effort, including:\n",
    "\n",
    "- TensorFLow\n",
    "- Pytorch\n",
    "- JAX\n",
    "\n",
    "Let's go through each step to have a look at different programs!\n",
    "\n",
    "#### JAX Program\n",
    "\n",
    "The below is a jax program to add an array and a scalar. It should make sense to you if you are familiar with NumPy."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "Array([[5, 6],\n",
       "       [7, 8]], dtype=int32)"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import jax\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def simple_add(x, y):\n",
    "    return jax.numpy.add(x, y)\n",
    "\n",
    "\n",
    "simple_add(np.array([[1, 2], [3, 4]]), 4)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### XLA Program\n",
    "\n",
    "Let's check what the XLA program for this JAX program looks like. JAX provides [xla_computation](https://jax.readthedocs.io/en/latest/_autosummary/jax.xla_computation.html) to convert JAX programs to XLA programs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'HloModule xla_computation_simple_add, entry_computation_layout={(s32[2,2]{1,0}, s32[])->(s32[2,2]{1,0})}\\n\\nENTRY main.6 {\\n  Arg_0.1 = s32[2,2]{1,0} parameter(0)\\n  Arg_1.2 = s32[] parameter(1)\\n  broadcast.3 = s32[2,2]{1,0} broadcast(Arg_1.2), dimensions={}\\n  add.4 = s32[2,2]{1,0} add(Arg_0.1, broadcast.3)\\n  ROOT tuple.5 = (s32[2,2]{1,0}) tuple(add.4)\\n}\\n\\n'"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "c = jax.xla_computation(simple_add)(np.array([[1, 2], [3, 4]]), 4)\n",
    "\n",
    "c.as_hlo_text()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You should be aware of the following facts:\n",
    "\n",
    "- shape and dtype is fixed in XLA program like **s32[2,2]{1,0}** in each command.\n",
    "- an implicit **broadcast** op is inserted before **add** op."
   ]
  },
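  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see the first fact in action, here is a minimal sketch (not part of the SPU workflow) that traces the same function with two different input shapes using **jax.make_jaxpr**. Each trace is specialized to the shape and dtype it was given, which is why one Python function can yield many different XLA programs. The variable names are illustrative only.\n",
    "\n",
    "```python\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def simple_add(x, y):\n",
    "    return jnp.add(x, y)\n",
    "\n",
    "\n",
    "# Trace the same function with a 2x2 input and with a length-3 input.\n",
    "jaxpr_2x2 = jax.make_jaxpr(simple_add)(np.ones((2, 2), np.float32), 4.0)\n",
    "jaxpr_3 = jax.make_jaxpr(simple_add)(np.ones((3,), np.float32), 4.0)\n",
    "\n",
    "# The abstract input and output values carry a fixed shape and dtype.\n",
    "print(jaxpr_2x2.in_avals[0].shape, jaxpr_2x2.out_avals[0].dtype)\n",
    "print(jaxpr_3.in_avals[0].shape, jaxpr_3.out_avals[0].dtype)\n",
    "```"
   ]
  },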
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### PPHlo Program\n",
    "\n",
    "Lastly, let's check the PPHlo program for this XLA program. **spu.compile** could convert XLA programs to PPHlo programs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "b'module @xla_computation_simple_add attributes {mhlo.cross_program_prefetches = [], mhlo.dynamic_parameter_bindings = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {\\n  func.func @main(%arg0: tensor<2x2x!pphlo.sec<i32>>, %arg1: tensor<!pphlo.sec<i32>>) -> tensor<2x2x!pphlo.sec<i32>> {\\n    %0 = \"pphlo.broadcast\"(%arg1) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<!pphlo.sec<i32>>) -> tensor<2x2x!pphlo.sec<i32>>\\n    %1 = \"pphlo.add\"(%arg0, %0) : (tensor<2x2x!pphlo.sec<i32>>, tensor<2x2x!pphlo.sec<i32>>) -> tensor<2x2x!pphlo.sec<i32>>\\n    return %1 : tensor<2x2x!pphlo.sec<i32>>\\n  }\\n}\\n'"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import spu\n",
    "import spu.spu_pb2 as spu_pb2\n",
    "\n",
    "source = spu_pb2.CompilationSource()\n",
    "source.ir_txt = c.as_serialized_hlo_module_proto()\n",
    "source.input_visibility.extend([spu.Visibility.VIS_SECRET, spu.Visibility.VIS_SECRET])\n",
    "source.ir_type = spu_pb2.SourceIRType.XLA\n",
    "\n",
    "pphlo = spu.compile(source, spu_pb2.CompilerOptions())\n",
    "\n",
    "pphlo"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You may find the PPHlo program is identical to XLA program. The only differences are:\n",
    "\n",
    "- You have to provide the input visibility to SPU compiler, i.e. **[spu.Visibility.VIS_SECRET, spu.Visibility.VIS_SECRET]** in our case.\n",
    "- Comparing to XLA program, **Visibility** is an extra attribute to all variables in PPHlo program like **tensor<2x2x!pphlo.sec<i32>>** means this is a secret 2x2 i32 tensor."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "SPU compiler would deduce visibility in each step, let's modify input visibility and check what would happen."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "b'module @xla_computation_simple_add attributes {mhlo.cross_program_prefetches = [], mhlo.dynamic_parameter_bindings = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {\\n  func.func @main(%arg0: tensor<2x2x!pphlo.sec<i32>>, %arg1: tensor<!pphlo.pub<i32>>) -> tensor<2x2x!pphlo.sec<i32>> {\\n    %0 = \"pphlo.broadcast\"(%arg1) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<!pphlo.pub<i32>>) -> tensor<2x2x!pphlo.pub<i32>>\\n    %1 = \"pphlo.add\"(%arg0, %0) : (tensor<2x2x!pphlo.sec<i32>>, tensor<2x2x!pphlo.pub<i32>>) -> tensor<2x2x!pphlo.sec<i32>>\\n    return %1 : tensor<2x2x!pphlo.sec<i32>>\\n  }\\n}\\n'"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "source = spu_pb2.CompilationSource()\n",
    "source.ir_txt = c.as_serialized_hlo_module_proto()\n",
    "source.input_visibility.extend([spu.Visibility.VIS_SECRET, spu.Visibility.VIS_PUBLIC])\n",
    "source.ir_type = spu_pb2.SourceIRType.XLA\n",
    "\n",
    "pphlo = spu.compile(source, spu_pb2.CompilerOptions())\n",
    "\n",
    "pphlo"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "### From JAX to SPU\n",
    "\n",
    "So this is the whole story. \n",
    "1. You write a JAX program in Python. \n",
    "2. Then you could turn JAX program to XLA program with the first-party API from JAX, i.e. jax.xla_computation. \n",
    "3. Afterwards,  SPU compiler could transfer XLA program to PPHlo program - the only language could be understood by SPU Runtime. \n",
    "4. In the end, the PPHlo program is sent to SPU Runtimes and executed.\n",
    "\n",
    "In SecretFlow, we have implemented some helper methods so that you could just write a JAX program in the beginning, we would take care of the remaining steps for you."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Write a Jittable JAX Program\n",
    "\n",
    "Jittable means a JAX program could be Just In Time (JIT) compilation into XLA program. So only when a JAX program is Jittable, it then could be possibly executed by SPU.\n",
    "\n",
    "Since SPU doesn't support all XLA operators, even a JAX program is jittable, SPU runtime still could refuse to execute. \n",
    "\n",
    "### JAX NumPy Package\n",
    "\n",
    "We could use these [NumPy-like APIs](https://jax.readthedocs.io/en/latest/jax.numpy.html) from JAX. JAX NumPy APIs are very similar to original ones, while\n",
    "- JAX NumPy arrays are immutable, so you have to use **ndarray.at** instead of in-place array updates\n",
    "- You have to provide some extra args to make the method call jittable(we would discuss this later).\n",
    "\n",
    "And actually SPU doesn't support all JAX NumPy operators, please also check [this documentation](http://www.secretflow.org.cn/docs/spu/en/reference/np_op_status.html). We are updating this document promptly and we have listed the current status of each operators.\n",
    "\n",
    "Next, we are going to write some JAX Numpy programs."
   ]
  },
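  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick illustration of the immutability caveat, the sketch below updates one element of a JAX array with **ndarray.at**. It runs in plain JAX and is independent of SPU. The names **x** and **y** are placeholders.\n",
    "\n",
    "```python\n",
    "import jax.numpy as jnp\n",
    "\n",
    "x = jnp.zeros(3)\n",
    "\n",
    "# x[0] = 10.0 would raise a TypeError because JAX arrays are immutable.\n",
    "# Instead, .at[...].set(...) returns an updated copy and leaves x untouched.\n",
    "y = x.at[0].set(10.0)\n",
    "\n",
    "print(x)\n",
    "print(y)\n",
    "```"
   ]
  },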
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Euclidean Distance\n",
    "\n",
    "Just one-line code we could compute Euclidean Distance of two points."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def euclidean_distance(p1, p2):\n",
    "    return jax.numpy.linalg.norm(p1 - p2)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's check whether it is jittable by **jax.jit**. You could also use **jax.xla_computation** for testing as well."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "5.0\n",
      "HloModule xla_computation_euclidean_distance, entry_computation_layout={(s32[2]{0},s32[2]{0})->(f32[])}\n",
      "\n",
      "region_0.4 {\n",
      "  Arg_0.5 = f32[] parameter(0)\n",
      "  Arg_1.6 = f32[] parameter(1)\n",
      "  ROOT add.7 = f32[] add(Arg_0.5, Arg_1.6)\n",
      "}\n",
      "\n",
      "norm.8 {\n",
      "  Arg_0.9 = s32[2]{0} parameter(0)\n",
      "  convert.11 = f32[2]{0} convert(Arg_0.9)\n",
      "  multiply.12 = f32[2]{0} multiply(convert.11, convert.11)\n",
      "  constant.10 = f32[] constant(0)\n",
      "  reduce.13 = f32[] reduce(multiply.12, constant.10), dimensions={0}, to_apply=region_0.4\n",
      "  ROOT sqrt.14 = f32[] sqrt(reduce.13)\n",
      "}\n",
      "\n",
      "ENTRY main.17 {\n",
      "  Arg_0.1 = s32[2]{0} parameter(0)\n",
      "  Arg_1.2 = s32[2]{0} parameter(1)\n",
      "  subtract.3 = s32[2]{0} subtract(Arg_0.1, Arg_1.2)\n",
      "  call.15 = f32[] call(subtract.3), to_apply=norm.8\n",
      "  ROOT tuple.16 = (f32[]) tuple(call.15)\n",
      "}\n",
      "\n",
      "\n"
     ]
    }
   ],
   "source": [
    "euclidean_distance_jit = jax.jit(euclidean_distance)\n",
    "\n",
    "print(euclidean_distance_jit(np.array([0, 0]), np.array([3, 4])))\n",
    "\n",
    "\n",
    "# or\n",
    "print(\n",
    "    (\n",
    "        jax.xla_computation(euclidean_distance)(np.array([0, 0]), np.array([3, 4]))\n",
    "    ).as_hlo_text()\n",
    ")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Area of a Simple Polygon\n",
    "\n",
    "Given a list of Cartesian coordinates of vertices of a simply polygon, we could calculate the area by [Shoelace formula](https://en.wikipedia.org/wiki/Shoelace_formula)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Array(16.5, dtype=float32)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import jax.numpy as jnp\n",
    "\n",
    "\n",
    "def area_of_simple_polygon(vertices):\n",
    "    area = 0\n",
    "    for i in range(0, vertices.shape[0]):\n",
    "        a = jnp.expand_dims(vertices[i, :], axis=0)\n",
    "        b = jnp.expand_dims(vertices[(i + 1) % vertices.shape[0], :], axis=0)\n",
    "        x = jax.numpy.concatenate((a, b))\n",
    "        x_t = jax.numpy.transpose(x)\n",
    "        area += 0.5 * jax.numpy.linalg.det(x_t)\n",
    "    return area\n",
    "\n",
    "\n",
    "vertices = np.array([[1, 6], [3, 1], [7, 2], [4, 4], [8, 5]])\n",
    "\n",
    "area_of_simple_polygon(vertices)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's check whether **area_of_simple_polygon** is jittable."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "HloModule xla_computation_area_of_simple_polygon, entry_computation_layout={(s32[5,2]{1,0})->(f32[])}\n",
      "\n",
      "det.7 {\n",
      "  Arg_0.8 = s32[2,2]{1,0} parameter(0)\n",
      "  convert.9 = f32[2,2]{1,0} convert(Arg_0.8)\n",
      "  slice.10 = f32[1,1]{1,0} slice(convert.9), slice={[0:1], [0:1]}\n",
      "  reshape.11 = f32[] reshape(slice.10)\n",
      "  slice.12 = f32[1,1]{1,0} slice(convert.9), slice={[1:2], [1:2]}\n",
      "  reshape.13 = f32[] reshape(slice.12)\n",
      "  multiply.14 = f32[] multiply(reshape.11, reshape.13)\n",
      "  slice.15 = f32[1,1]{1,0} slice(convert.9), slice={[0:1], [1:2]}\n",
      "  reshape.16 = f32[] reshape(slice.15)\n",
      "  slice.17 = f32[1,1]{1,0} slice(convert.9), slice={[1:2], [0:1]}\n",
      "  reshape.18 = f32[] reshape(slice.17)\n",
      "  multiply.19 = f32[] multiply(reshape.16, reshape.18)\n",
      "  ROOT subtract.20 = f32[] subtract(multiply.14, multiply.19)\n",
      "}\n",
      "\n",
      "det_0.27 {\n",
      "  Arg_0.28 = s32[2,2]{1,0} parameter(0)\n",
      "  convert.29 = f32[2,2]{1,0} convert(Arg_0.28)\n",
      "  slice.30 = f32[1,1]{1,0} slice(convert.29), slice={[0:1], [0:1]}\n",
      "  reshape.31 = f32[] reshape(slice.30)\n",
      "  slice.32 = f32[1,1]{1,0} slice(convert.29), slice={[1:2], [1:2]}\n",
      "  reshape.33 = f32[] reshape(slice.32)\n",
      "  multiply.34 = f32[] multiply(reshape.31, reshape.33)\n",
      "  slice.35 = f32[1,1]{1,0} slice(convert.29), slice={[0:1], [1:2]}\n",
      "  reshape.36 = f32[] reshape(slice.35)\n",
      "  slice.37 = f32[1,1]{1,0} slice(convert.29), slice={[1:2], [0:1]}\n",
      "  reshape.38 = f32[] reshape(slice.37)\n",
      "  multiply.39 = f32[] multiply(reshape.36, reshape.38)\n",
      "  ROOT subtract.40 = f32[] subtract(multiply.34, multiply.39)\n",
      "}\n",
      "\n",
      "det_1.48 {\n",
      "  Arg_0.49 = s32[2,2]{1,0} parameter(0)\n",
      "  convert.50 = f32[2,2]{1,0} convert(Arg_0.49)\n",
      "  slice.51 = f32[1,1]{1,0} slice(convert.50), slice={[0:1], [0:1]}\n",
      "  reshape.52 = f32[] reshape(slice.51)\n",
      "  slice.53 = f32[1,1]{1,0} slice(convert.50), slice={[1:2], [1:2]}\n",
      "  reshape.54 = f32[] reshape(slice.53)\n",
      "  multiply.55 = f32[] multiply(reshape.52, reshape.54)\n",
      "  slice.56 = f32[1,1]{1,0} slice(convert.50), slice={[0:1], [1:2]}\n",
      "  reshape.57 = f32[] reshape(slice.56)\n",
      "  slice.58 = f32[1,1]{1,0} slice(convert.50), slice={[1:2], [0:1]}\n",
      "  reshape.59 = f32[] reshape(slice.58)\n",
      "  multiply.60 = f32[] multiply(reshape.57, reshape.59)\n",
      "  ROOT subtract.61 = f32[] subtract(multiply.55, multiply.60)\n",
      "}\n",
      "\n",
      "det_2.69 {\n",
      "  Arg_0.70 = s32[2,2]{1,0} parameter(0)\n",
      "  convert.71 = f32[2,2]{1,0} convert(Arg_0.70)\n",
      "  slice.72 = f32[1,1]{1,0} slice(convert.71), slice={[0:1], [0:1]}\n",
      "  reshape.73 = f32[] reshape(slice.72)\n",
      "  slice.74 = f32[1,1]{1,0} slice(convert.71), slice={[1:2], [1:2]}\n",
      "  reshape.75 = f32[] reshape(slice.74)\n",
      "  multiply.76 = f32[] multiply(reshape.73, reshape.75)\n",
      "  slice.77 = f32[1,1]{1,0} slice(convert.71), slice={[0:1], [1:2]}\n",
      "  reshape.78 = f32[] reshape(slice.77)\n",
      "  slice.79 = f32[1,1]{1,0} slice(convert.71), slice={[1:2], [0:1]}\n",
      "  reshape.80 = f32[] reshape(slice.79)\n",
      "  multiply.81 = f32[] multiply(reshape.78, reshape.80)\n",
      "  ROOT subtract.82 = f32[] subtract(multiply.76, multiply.81)\n",
      "}\n",
      "\n",
      "det_3.90 {\n",
      "  Arg_0.91 = s32[2,2]{1,0} parameter(0)\n",
      "  convert.92 = f32[2,2]{1,0} convert(Arg_0.91)\n",
      "  slice.93 = f32[1,1]{1,0} slice(convert.92), slice={[0:1], [0:1]}\n",
      "  reshape.94 = f32[] reshape(slice.93)\n",
      "  slice.95 = f32[1,1]{1,0} slice(convert.92), slice={[1:2], [1:2]}\n",
      "  reshape.96 = f32[] reshape(slice.95)\n",
      "  multiply.97 = f32[] multiply(reshape.94, reshape.96)\n",
      "  slice.98 = f32[1,1]{1,0} slice(convert.92), slice={[0:1], [1:2]}\n",
      "  reshape.99 = f32[] reshape(slice.98)\n",
      "  slice.100 = f32[1,1]{1,0} slice(convert.92), slice={[1:2], [0:1]}\n",
      "  reshape.101 = f32[] reshape(slice.100)\n",
      "  multiply.102 = f32[] multiply(reshape.99, reshape.101)\n",
      "  ROOT subtract.103 = f32[] subtract(multiply.97, multiply.102)\n",
      "}\n",
      "\n",
      "ENTRY main.108 {\n",
      "  Arg_0.1 = s32[5,2]{1,0} parameter(0)\n",
      "  slice.3 = s32[1,2]{1,0} slice(Arg_0.1), slice={[0:1], [0:2]}\n",
      "  slice.4 = s32[1,2]{1,0} slice(Arg_0.1), slice={[1:2], [0:2]}\n",
      "  concatenate.5 = s32[2,2]{1,0} concatenate(slice.3, slice.4), dimensions={0}\n",
      "  transpose.6 = s32[2,2]{0,1} transpose(concatenate.5), dimensions={1,0}\n",
      "  call.21 = f32[] call(transpose.6), to_apply=det.7\n",
      "  constant.2 = f32[] constant(0.5)\n",
      "  multiply.22 = f32[] multiply(call.21, constant.2)\n",
      "  slice.23 = s32[1,2]{1,0} slice(Arg_0.1), slice={[1:2], [0:2]}\n",
      "  slice.24 = s32[1,2]{1,0} slice(Arg_0.1), slice={[2:3], [0:2]}\n",
      "  concatenate.25 = s32[2,2]{1,0} concatenate(slice.23, slice.24), dimensions={0}\n",
      "  transpose.26 = s32[2,2]{0,1} transpose(concatenate.25), dimensions={1,0}\n",
      "  call.41 = f32[] call(transpose.26), to_apply=det_0.27\n",
      "  multiply.42 = f32[] multiply(call.41, constant.2)\n",
      "  add.43 = f32[] add(multiply.22, multiply.42)\n",
      "  slice.44 = s32[1,2]{1,0} slice(Arg_0.1), slice={[2:3], [0:2]}\n",
      "  slice.45 = s32[1,2]{1,0} slice(Arg_0.1), slice={[3:4], [0:2]}\n",
      "  concatenate.46 = s32[2,2]{1,0} concatenate(slice.44, slice.45), dimensions={0}\n",
      "  transpose.47 = s32[2,2]{0,1} transpose(concatenate.46), dimensions={1,0}\n",
      "  call.62 = f32[] call(transpose.47), to_apply=det_1.48\n",
      "  multiply.63 = f32[] multiply(call.62, constant.2)\n",
      "  add.64 = f32[] add(add.43, multiply.63)\n",
      "  slice.65 = s32[1,2]{1,0} slice(Arg_0.1), slice={[3:4], [0:2]}\n",
      "  slice.66 = s32[1,2]{1,0} slice(Arg_0.1), slice={[4:5], [0:2]}\n",
      "  concatenate.67 = s32[2,2]{1,0} concatenate(slice.65, slice.66), dimensions={0}\n",
      "  transpose.68 = s32[2,2]{0,1} transpose(concatenate.67), dimensions={1,0}\n",
      "  call.83 = f32[] call(transpose.68), to_apply=det_2.69\n",
      "  multiply.84 = f32[] multiply(call.83, constant.2)\n",
      "  add.85 = f32[] add(add.64, multiply.84)\n",
      "  slice.86 = s32[1,2]{1,0} slice(Arg_0.1), slice={[4:5], [0:2]}\n",
      "  slice.87 = s32[1,2]{1,0} slice(Arg_0.1), slice={[0:1], [0:2]}\n",
      "  concatenate.88 = s32[2,2]{1,0} concatenate(slice.86, slice.87), dimensions={0}\n",
      "  transpose.89 = s32[2,2]{0,1} transpose(concatenate.88), dimensions={1,0}\n",
      "  call.104 = f32[] call(transpose.89), to_apply=det_3.90\n",
      "  multiply.105 = f32[] multiply(call.104, constant.2)\n",
      "  add.106 = f32[] add(add.85, multiply.105)\n",
      "  ROOT tuple.107 = (f32[]) tuple(add.106)\n",
      "}\n",
      "\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print(jax.xla_computation(area_of_simple_polygon)(vertices).as_hlo_text())"
   ]
  },
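  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The Python **for** loop in **area_of_simple_polygon** is unrolled at trace time, which is why the HLO above contains one **det** computation per edge. As a hedged alternative, the sketch below applies the same shoelace formula with array operations only, using **jnp.roll** to pair each vertex with its successor. The function name is hypothetical, and whether every op involved is supported by SPU should be checked against the SPU documentation.\n",
    "\n",
    "```python\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def area_of_simple_polygon_vectorized(vertices):\n",
    "    x, y = vertices[:, 0], vertices[:, 1]\n",
    "    # jnp.roll(..., -1) pairs every vertex with its successor (wrapping around).\n",
    "    cross = x * jnp.roll(y, -1) - jnp.roll(x, -1) * y\n",
    "    # Like the loop version, this returns the signed area.\n",
    "    return 0.5 * jnp.sum(cross)\n",
    "\n",
    "\n",
    "vertices = np.array([[1, 6], [3, 1], [7, 2], [4, 4], [8, 5]])\n",
    "area = jax.jit(area_of_simple_polygon_vectorized)(vertices)\n",
    "print(area)\n",
    "```"
   ]
  },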
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Could We Jit Anything?\n",
    "\n",
    "Absolutely not, please check [this documentation](https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html#why-can-t-we-just-jit-everything) from JAX!\n",
    "\n",
    "The most common cause to unjittable program is your control flow relies on the value of **input**. For instance,"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Traceback (most recent call last):\n",
      "  File \"/tmp/ipykernel_2314574/682526420.py\", line 14, in <module>\n",
      "    g_jit(10, 20)  # Should raise an error.\n",
      "  File \"/home/fengjun.feng/miniconda3/envs/sf/lib/python3.8/site-packages/jax/_src/traceback_util.py\", line 163, in reraise_with_filtered_traceback\n",
      "    return fun(*args, **kwargs)\n",
      "  File \"/home/fengjun.feng/miniconda3/envs/sf/lib/python3.8/site-packages/jax/_src/pjit.py\", line 235, in cache_miss\n",
      "    outs, out_flat, out_tree, args_flat = _python_pjit_helper(\n",
      "  File \"/home/fengjun.feng/miniconda3/envs/sf/lib/python3.8/site-packages/jax/_src/pjit.py\", line 179, in _python_pjit_helper\n",
      "    args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(\n",
      "  File \"/home/fengjun.feng/miniconda3/envs/sf/lib/python3.8/site-packages/jax/_src/api.py\", line 440, in infer_params\n",
      "    return pjit.common_infer_params(pjit_info_args, *args, **kwargs)\n",
      "  File \"/home/fengjun.feng/miniconda3/envs/sf/lib/python3.8/site-packages/jax/_src/pjit.py\", line 513, in common_infer_params\n",
      "    jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(\n",
      "  File \"/home/fengjun.feng/miniconda3/envs/sf/lib/python3.8/site-packages/jax/_src/pjit.py\", line 965, in _pjit_jaxpr\n",
      "    jaxpr, final_consts, global_out_avals = _create_pjit_jaxpr(\n",
      "  File \"/home/fengjun.feng/miniconda3/envs/sf/lib/python3.8/site-packages/jax/_src/linear_util.py\", line 301, in memoized_fun\n",
      "    ans = call(fun, *args)\n",
      "  File \"/home/fengjun.feng/miniconda3/envs/sf/lib/python3.8/site-packages/jax/_src/pjit.py\", line 923, in _create_pjit_jaxpr\n",
      "    jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(\n",
      "  File \"/home/fengjun.feng/miniconda3/envs/sf/lib/python3.8/site-packages/jax/_src/profiler.py\", line 314, in wrapper\n",
      "    return func(*args, **kwargs)\n",
      "  File \"/home/fengjun.feng/miniconda3/envs/sf/lib/python3.8/site-packages/jax/interpreters/partial_eval.py\", line 2033, in trace_to_jaxpr_dynamic\n",
      "    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(\n",
      "  File \"/home/fengjun.feng/miniconda3/envs/sf/lib/python3.8/site-packages/jax/interpreters/partial_eval.py\", line 2050, in trace_to_subjaxpr_dynamic\n",
      "    ans = fun.call_wrapped(*in_tracers_)\n",
      "  File \"/home/fengjun.feng/miniconda3/envs/sf/lib/python3.8/site-packages/jax/_src/linear_util.py\", line 165, in call_wrapped\n",
      "    ans = self.f(*args, **dict(self.params, **kwargs))\n",
      "  File \"/tmp/ipykernel_2314574/682526420.py\", line 6, in g\n",
      "    while i < n:\n",
      "  File \"/home/fengjun.feng/miniconda3/envs/sf/lib/python3.8/site-packages/jax/_src/core.py\", line 653, in __bool__\n",
      "    def __bool__(self): return self.aval._bool(self)\n",
      "  File \"/home/fengjun.feng/miniconda3/envs/sf/lib/python3.8/site-packages/jax/_src/core.py\", line 1344, in error\n",
      "    raise ConcretizationTypeError(arg, fname_context)\n",
      "jax._src.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>\n",
      "The problem arose with the `bool` function. \n",
      "The error occurred while tracing the function g at /tmp/ipykernel_2314574/682526420.py:4 for jit. This concrete value was not available in Python because it depends on the value of the argument 'n'.\n",
      "\n",
      "See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError\n"
     ]
    }
   ],
   "source": [
    "# Cited from JAX documentation.\n",
    "# While loop conditioned on x and n.\n",
    "\n",
    "\n",
    "def g(x, n):\n",
    "    i = 0\n",
    "    while i < n:\n",
    "        i += 1\n",
    "    return x + i\n",
    "\n",
    "\n",
    "g_jit = jax.jit(g)\n",
    "\n",
    "import traceback\n",
    "\n",
    "try:\n",
    "    g_jit(10, 20)  # Should raise an error.\n",
    "except Exception:\n",
    "    traceback.print_exc()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "There are two possible solutions.\n",
    "1. You could replace control flow with [low-level **jax.lax** APIs](https://jax.readthedocs.io/en/latest/jax.lax.html#control-flow-operators). You need to spend some time figure out how to use these APIs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Array(30, dtype=int32, weak_type=True)"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def g_with_lax_control_flow(x, n):\n",
    "    def body_fun(i):\n",
    "        i += 1\n",
    "        return i\n",
    "\n",
    "    return x + jax.lax.while_loop(lambda i: i < n, body_fun, 0)\n",
    "\n",
    "\n",
    "g_with_lax_control_flow_jit = jax.jit(g_with_lax_control_flow)\n",
    "g_with_lax_control_flow_jit(10, 20)  # good to go!"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "2. The other possible solution is to use **static_argnames**."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Array(30, dtype=int32, weak_type=True)"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "g_with_static_argnames_jit = jax.jit(g, static_argnames=['n'])\n",
    "g_with_static_argnames_jit(10, 20)  # good to go!"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "so which method we should choose when the program is unjittable?\n",
    "\n",
    "This is our suggestion:\n",
    "\n",
    "- Rewrite the control flow with **jax.lax** APIs first. Although these are some learning costs here, but it deserves that.\n",
    "- If the visibility of affected input values are **VIS_PUBLIC** like **n** in the above example, you could mark it as **static_argnames** and these affected input values would be compiled into XLA program."
   ]
  },
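  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The same idea applies to **if** statements whose condition depends on an input. The sketch below is a minimal plain-JAX example that replaces a Python branch with **jax.lax.cond** so the function stays jittable. The function **abs_diff** is hypothetical, and support for a given **jax.lax** op in SPU should be verified separately.\n",
    "\n",
    "```python\n",
    "import jax\n",
    "\n",
    "\n",
    "def abs_diff(a, b):\n",
    "    # A Python `if a > b:` would fail under jit because a and b are tracers.\n",
    "    # jax.lax.cond selects a branch based on the traced predicate instead.\n",
    "    return jax.lax.cond(\n",
    "        a > b,\n",
    "        lambda a, b: a - b,  # taken when a > b\n",
    "        lambda a, b: b - a,  # taken otherwise\n",
    "        a,\n",
    "        b,\n",
    "    )\n",
    "\n",
    "\n",
    "abs_diff_jit = jax.jit(abs_diff)\n",
    "result = abs_diff_jit(3, 10)\n",
    "print(result)\n",
    "```"
   ]
  },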
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### More Examples\n",
    "\n",
    "If you would like to check more examples, please check [Python examples](https://github.com/secretflow/spu/tree/main/examples/python) in SPU repo. In most examples, the MPC part are written with **jax.numpy** package. And you could find we are using **jax.lax** APIs and **static_argnames** heavily to make JAX program jittable!"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Execute JAX Program with SPU\n",
    "\n",
    "Once you have your jittable JAX program ready, we could execute them with SPU!\n",
    "\n",
    "\n",
    "### (Optional) SPU Simulation\n",
    "\n",
    "\n",
    "If you hope to get a quick try, I would like to introduce **spu.sim_jax** to you. Let's show how does it work!\n",
    "\n",
    "\n",
    "> **spu.sim_jax** is only exposed after **spu v0.3.1b8**.\n",
    "\n",
    "Here we create an SPU simulator with the following settings:\n",
    "- world size of 3.\n",
    "- with ABY3 protocol. Check all supported protocol [here](http://www.secretflow.org.cn/docs/spu/en/reference/mpc_status.html#supported-mpc-protocol).\n",
    "- field of 64 which the values in SPU are expressed in 2^64 ring.\n",
    "\n",
    "However, if you just want to confirm if the JAX program could be executed by SPU, any settings should be fine. Different settings could only affect the elapsed time and precision of computation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array(4.999962, dtype=float32)"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from spu.utils.simulation import Simulator, sim_jax\n",
    "\n",
    "sim = Simulator.simple(3, spu.ProtocolKind.ABY3, spu.FieldType.FM64)\n",
    "\n",
    "spu_euclidean_distance_fn = sim_jax(sim, euclidean_distance)\n",
    "\n",
    "spu_euclidean_distance_fn(np.array([0, 0]), np.array([3, 4]))"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you execute the code above repeatedly, you may find the result is slightly different between runs, which is expected due to randomness in MPC computation.\n",
    "\n",
    "After testing with **euclidean_distance**, we have a try with **area_of_simple_polygon**."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array(16.5, dtype=float32)"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "spu_area_of_simple_polygon_fn = sim_jax(sim, area_of_simple_polygon)\n",
    "\n",
    "spu_area_of_simple_polygon_fn(vertices)"
   ]
  },
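  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To give some intuition for what representing values in a 2^64 ring means, here is a deliberately simplified sketch of additive secret sharing among three parties in pure Python. It is an assumption-laden illustration only: it is not the ABY3 protocol, it ignores fixed-point encoding of floats, and the helper names **share** and **reconstruct** are hypothetical and not part of the SPU API.\n",
    "\n",
    "```python\n",
    "import secrets\n",
    "\n",
    "RING = 2**64  # all arithmetic is done modulo 2^64\n",
    "\n",
    "\n",
    "def share(value, n_parties=3):\n",
    "    # Pick random shares for all but one party; the last share makes the sum match.\n",
    "    shares = [secrets.randbelow(RING) for _ in range(n_parties - 1)]\n",
    "    shares.append((value - sum(shares)) % RING)\n",
    "    return shares\n",
    "\n",
    "\n",
    "def reconstruct(shares):\n",
    "    return sum(shares) % RING\n",
    "\n",
    "\n",
    "a_shares = share(3)\n",
    "b_shares = share(4)\n",
    "\n",
    "# Each party adds its own shares locally; no party sees the inputs 3 or 4.\n",
    "sum_shares = [(a + b) % RING for a, b in zip(a_shares, b_shares)]\n",
    "print(reconstruct(sum_shares))\n",
    "```\n",
    "\n",
    "The shares are freshly random on every run, but the reconstructed sum of the two integers is always the same."
   ]
  },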
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Run with SPU Device\n",
    "\n",
    "Finally, we are going to run the JAX program with SecretFlow.\n",
    "\n",
    "I guess you should be familiar with the following steps if you have checked out other tutorials.\n",
    "\n",
    "Here we create a local standalone SecretFlow cluster with three devices:\n",
    "\n",
    "- Two PYU device - **alice** and **bob**\n",
    "- An SPU device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2023-03-21 21:03:27,671\tINFO worker.py:1538 -- Started a local Ray instance.\n"
     ]
    }
   ],
   "source": [
    "import secretflow as sf\n",
    "\n",
    "# Check the version of your SecretFlow\n",
    "print('The version of SecretFlow: {}'.format(sf.__version__))\n",
    "\n",
    "# Shut down any SecretFlow runtime that is already running.\n",
    "sf.shutdown()\n",
    "\n",
    "sf.init(parties=['alice', 'bob'], address='local')\n",
    "\n",
    "alice, bob = sf.PYU('alice'), sf.PYU('bob')\n",
    "spu = sf.SPU(sf.utils.testing.cluster_def(['alice', 'bob']))"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We try **euclidean_distance** with the SPU device first."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[2m\u001b[36m(_run pid=2316721)\u001b[0m INFO:jax._src.xla_bridge:Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker: \n",
      "\u001b[2m\u001b[36m(_run pid=2316721)\u001b[0m INFO:jax._src.xla_bridge:Unable to initialize backend 'cuda': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'\n",
      "\u001b[2m\u001b[36m(_run pid=2316721)\u001b[0m INFO:jax._src.xla_bridge:Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'\n",
      "\u001b[2m\u001b[36m(_run pid=2316721)\u001b[0m INFO:jax._src.xla_bridge:Unable to initialize backend 'tpu': INVALID_ARGUMENT: TpuPlatform is not available.\n",
      "\u001b[2m\u001b[36m(_run pid=2316721)\u001b[0m INFO:jax._src.xla_bridge:Unable to initialize backend 'plugin': xla_extension has no attributes named get_plugin_device_client. Compile TensorFlow with //tensorflow/compiler/xla/python:enable_plugin_device set to true (defaults to false) to enable this.\n",
      "\u001b[2m\u001b[36m(_run pid=2316721)\u001b[0m WARNING:jax._src.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n",
      "\u001b[2m\u001b[36m(_run pid=2316358)\u001b[0m INFO:jax._src.xla_bridge:Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker: \n",
      "\u001b[2m\u001b[36m(_run pid=2316358)\u001b[0m INFO:jax._src.xla_bridge:Unable to initialize backend 'cuda': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'\n",
      "\u001b[2m\u001b[36m(_run pid=2316358)\u001b[0m INFO:jax._src.xla_bridge:Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'\n",
      "\u001b[2m\u001b[36m(_run pid=2316358)\u001b[0m INFO:jax._src.xla_bridge:Unable to initialize backend 'tpu': INVALID_ARGUMENT: TpuPlatform is not available.\n",
      "\u001b[2m\u001b[36m(_run pid=2316358)\u001b[0m INFO:jax._src.xla_bridge:Unable to initialize backend 'plugin': xla_extension has no attributes named get_plugin_device_client. Compile TensorFlow with //tensorflow/compiler/xla/python:enable_plugin_device set to true (defaults to false) to enable this.\n",
      "\u001b[2m\u001b[36m(_run pid=2316358)\u001b[0m WARNING:jax._src.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "array(5., dtype=float32)"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "p1 = sf.to(alice, np.array([0, 0]))\n",
    "p2 = sf.to(bob, np.array([3, 4]))\n",
    "\n",
    "distance = spu(euclidean_distance)(p1, p2)\n",
    "\n",
    "sf.reveal(distance)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then we try **area_of_simple_polygon**."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array(16.5, dtype=float32)"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "v = sf.to(alice, vertices)\n",
    "area = spu(area_of_simple_polygon)(v)\n",
    "\n",
    "sf.reveal(area)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Summary\n",
    "\n",
    "This is the end of the tutorial. Let's summarize the steps for privacy-preserving scientific computation with the JAX NumPy APIs:\n",
    "\n",
    "1. Write a jittable JAX NumPy program. You should test it with **jax.jit** or **jax.xla_computation**.\n",
    "2. (Optional) Try the JAX program with **SPU simulation**.\n",
    "3. Run this JAX NumPy program with an SPU device in SecretFlow.\n",
    "\n",
    "\n",
    "If your JAX program is jittable but fails in the SPU compiler or runtime, please check [JAX NumPy Operators Status](http://www.secretflow.org.cn/docs/spu/en/reference/np_op_status.html) and [XLA Implementation Status](http://www.secretflow.org.cn/docs/spu/en/reference/xla_status.html). You can also contact us directly through [GitHub Issues](https://github.com/secretflow/spu/issues)."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.16"
  },
  "vscode": {
   "interpreter": {
    "hash": "132606ca73fb74a80a732f3bad8b629a204996f9ecf3f30a045620700a49c15e"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
